{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Potentially useful links\n",
    "- https://towardsdatascience.com/monte-carlo-learning-b83f75233f92\n",
    "- https://www.kth.se/social/files/58b941d5f276542843812288/RL04-Monte-Carlo.pdf\n",
    "- https://medium.com/@zsalloum/q-vs-v-in-reinforcement-learning-the-easy-way-9350e1523031\n",
    "- https://medium.com/@zsalloum/monte-carlo-in-reinforcement-learning-the-easy-way-564c53010511\n",
    "- https://medium.com/@pedrohbtp/ai-monte-carlo-tree-search-mcts-49607046b204"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from random import randrange\n",
    "import numpy as np\n",
    "from Gridworld import Gridworld, DEFAULT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monte Carlo Algorithm\n",
    "\n",
    "```\n",
    "Intialise random policy with (s) = rand(a)\n",
    "Initialise Q table with (s,a) = 0 \n",
    "Initialise returns (s,a) = []\n",
    "\n",
    "for N episodes\n",
    "    state_action_returns = play_game(policy)\n",
    "    \n",
    "    for state_action_returns\n",
    "        if not seen\n",
    "            add state action to returns\n",
    "            new Q value for (s a) = average (returns (s, a))\n",
    "    \n",
    "    for state in policy\n",
    "        max action for each Q(state)\n",
    "        policy for that state = Max action\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
    "def greed_action(action, state):\n",
    "    if (random.random() > EPSILON):\n",
    "        return action\n",
    "    else:\n",
    "        return random.choice([ action for action, reward in g.get_valid_moves(state) ])\n",
    "\n",
    "def play_game(policy):\n",
    "    g = Gridworld(DEFAULT)\n",
    "    state_action_rewards = []\n",
    "    \n",
    "    state = (2,0)\n",
    "    action = greed_action(policy[state], state)\n",
    "    \n",
    "    state_action_rewards.append((state, action, 0))\n",
    "\n",
    "    \n",
    "    while True:\n",
    "        probs = g.transition_probabilities(action,state)\n",
    "        \n",
    "        # next_state \n",
    "        idx = np.random.choice(len(probs), p=[p for p, r, n_s in probs])\n",
    "        _, reward, next_state = probs[idx]\n",
    "        state = next_state\n",
    "        \n",
    "        \n",
    "        if state in g.terminal_states:\n",
    "            state_action_rewards.append((state, None, reward))\n",
    "            break\n",
    "        else:\n",
    "            action = greed_action(policy[state], state)\n",
    "            state_action_rewards.append((state, action, reward))\n",
    "        \n",
    "        state = next_state\n",
    "\n",
    "    return state_action_rewards\n",
    "\n",
    "def calculate_returns(state_action_rewards):\n",
    "    '''\n",
    "    Takes an array of (s[t-1], a[t-1], r)\n",
    "    Returns an array of (s, a, G)\n",
    "    '''\n",
    "    \n",
    "    state_action_returns = []\n",
    "    \n",
    "    state, action, reward = state_action_rewards.pop()\n",
    "    G = reward\n",
    "    \n",
    "    for state, action, reward in reversed(state_action_rewards):\n",
    "        state_action_returns.append((state, action , G))\n",
    "        \n",
    "        G = reward + GAMMA * G\n",
    "    \n",
    "    state_action_returns.reverse()\n",
    "    return state_action_returns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EPISODES = 1000\n",
    "EPSILON = 0.2\n",
    "GAMMA = 0.99\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "policy = {}\n",
    "Q = {}\n",
    "returns = {}\n",
    "g = Gridworld(DEFAULT)\n",
    "\n",
    "# Set up variables\n",
    "for state in g.available_states:\n",
    "    policy[state] = random.choice([ action for action, reward in g.get_valid_moves(state) ])\n",
    "    \n",
    "    Q[state] = {}\n",
    "    for action, reward in g.get_valid_moves(state):\n",
    "        Q[state][action] = 0\n",
    "        returns[state, action] = []\n",
    "\n",
    "# Run for x Episodes\n",
    "for i in range(1000):\n",
    "    state_action_rewards = play_game(policy)\n",
    "    state_action_returns = calculate_returns(state_action_rewards)\n",
    "    \n",
    "    # Add new state action to returns\n",
    "    for state, action, G in state_action_returns:\n",
    "        returns[state, action].append(G)\n",
    "    \n",
    "    # Update Q table\n",
    "    for state in g.available_states:\n",
    "        for action, reward in g.get_valid_moves(state):\n",
    "            if (len(returns[state, action]) > 0):\n",
    "                Q[state][action] = sum(returns[state, action]) / len(returns[state, action])\n",
    "    \n",
    "    # Update policy\n",
    "    for state in g.available_states:\n",
    "        policy[state] = max(Q[state], key=Q[state].get)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{(0, 0): {'UP': -0.5147121789103535,\n",
       "  'RIGHT': 0.16527512783723738,\n",
       "  'DOWN': -0.12284343844970244,\n",
       "  'LEFT': -0.501316441956495},\n",
       " (0, 1): {'UP': -0.24572017086164472,\n",
       "  'RIGHT': 0.48005673441673524,\n",
       "  'LEFT': -0.0468073990350339},\n",
       " (0, 2): {'UP': -0.4651314378919191,\n",
       "  'RIGHT': 0.6795883252569955,\n",
       "  'LEFT': 0.3339049480603574},\n",
       " (0, 3): {'UP': 0.2844132689966502,\n",
       "  'RIGHT': 0.882077726177756,\n",
       "  'DOWN': 0.3090594929968574,\n",
       "  'LEFT': 0.5133695358656334},\n",
       " (1, 0): {'UP': 0.053413383631128686,\n",
       "  'DOWN': -0.48114131787488207,\n",
       "  'LEFT': -0.7385887675980373},\n",
       " (1, 3): {'UP': 0.5558970839230097,\n",
       "  'RIGHT': -0.9031763715625,\n",
       "  'DOWN': -0.8737793282876933},\n",
       " (2, 0): {'UP': -0.1396153585062657,\n",
       "  'RIGHT': -1.2522548360947725,\n",
       "  'DOWN': -1.2141019626877751,\n",
       "  'LEFT': -1.0670681424914923},\n",
       " (2, 1): {'RIGHT': -1.4758132298191304,\n",
       "  'DOWN': -2.6382643765723524,\n",
       "  'LEFT': -0.8801771238215549},\n",
       " (2, 2): {'RIGHT': -1.3094026287941996,\n",
       "  'DOWN': -2.33418860550197,\n",
       "  'LEFT': -1.509882733587003},\n",
       " (2, 3): {'UP': -1.5750801818454598,\n",
       "  'RIGHT': -1.0027384649411772,\n",
       "  'DOWN': -1.7675952728751432,\n",
       "  'LEFT': -1.724485187208308},\n",
       " (2, 4): {'UP': -1.018774418604651,\n",
       "  'RIGHT': -1.08985,\n",
       "  'DOWN': -3.8356026398846836,\n",
       "  'LEFT': -1.1194000000000002}}"
      ]
     },
     "execution_count": 193,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
